<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_imagenet_train_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
    This program was used to train the resnet34_1000_imagenet_classifier.dnn
    network used by the <a href="dnn_imagenet_ex.cpp.html">dnn_imagenet_ex.cpp</a> example program.  

    You should be familiar with dlib's DNN module before reading this example
    program.  So read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> first.  
*/</font>



<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iterator<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>thread<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;
 
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<!-- residual: alias template taking a block template, an int N, a normalization layer
     template BN and the input SUBNET.  It wraps SUBNET in tag1, instantiates block with
     (N, BN, 1, tagged subnet) and wraps the result in add_prev1.
     NOTE(review): add_prev1 presumably adds the tag1 output back onto the block output
     (the residual shortcut); confirm against the dlib layer documentation. --><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font> <font color='#0000FF'>class</font> <b><a name='block'></a>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> residual <font color='#5555FF'>=</font> add_prev1<font color='#5555FF'>&lt;</font>block<font color='#5555FF'>&lt;</font>N,BN,<font color='#979000'>1</font>,tag1<font color='#5555FF'>&lt;</font>SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;

<!-- residual_down: like residual, but the block is instantiated with 2 as its third
     argument (the stride parameter of the block alias defined below).  Reading the
     nesting from the inside out: SUBNET is tagged tag1, the block is applied and tagged
     tag2, skip1 and avg_pool(2,2,2,2) follow, and add_prev2 wraps the result.
     NOTE(review): skip1 presumably routes the tag1 output into the pooling layer and
     add_prev2 adds the tag2 output to it; confirm against the dlib layer documentation.
     Markup note: the template parameters block and BN carry no anchors here because the
     names 'block' and 'BN' are already anchored earlier in this document and anchor
     names must be unique. --><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font> <font color='#0000FF'>class</font> <b>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font><font color='#0000FF'>class</font> <b>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> residual_down <font color='#5555FF'>=</font> add_prev2<font color='#5555FF'>&lt;</font>avg_pool<font color='#5555FF'>&lt;</font><font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,skip1<font color='#5555FF'>&lt;</font>tag2<font color='#5555FF'>&lt;</font>block<font color='#5555FF'>&lt;</font>N,BN,<font color='#979000'>2</font>,tag1<font color='#5555FF'>&lt;</font>SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;

<!-- block: the basic two convolution unit.  From the input outwards it is
     con(N,3,3,stride,stride), BN, relu, con(N,3,3,1,1), BN, where BN is the layer
     template supplied by the caller (bn_con or affine in the aliases that follow).
     Markup note: the template parameter BN carries no anchor here because the name 'BN'
     is already anchored earlier in this document and anchor names must be unique. --><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font><font color='#5555FF'>&gt;</font> <font color='#0000FF'>class</font> <b>BN</b>, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> 
<font color='#0000FF'>using</font> block  <font color='#5555FF'>=</font> BN<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,relu<font color='#5555FF'>&lt;</font>BN<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,stride,stride,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;


<!-- Convenience aliases that fix the block and the normalization layer and put a relu on
     top of the residual unit:
       res       : residual      with bn_con
       ares      : residual      with affine
       res_down  : residual_down with bn_con
       ares_down : residual_down with affine
     The bn_con variants are used by the training network type and the affine variants by
     the testing network type defined further down. --><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> res       <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual<font color='#5555FF'>&lt;</font>block,N,bn_con,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> ares      <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual<font color='#5555FF'>&lt;</font>block,N,affine,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> res_down  <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual_down<font color='#5555FF'>&lt;</font>block,N,bn_con,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> ares_down <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>residual_down<font color='#5555FF'>&lt;</font>block,N,affine,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;


<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<!-- Stacks of residual units, one alias per stage of the network.  The innermost unit of
     each stage is the one applied first to SUBNET:
       level4 / alevel4 : 3 units with N = 64, no downsampling unit
       level3 / alevel3 : 1 downsampling unit followed by 3 units, N = 128
       level2 / alevel2 : 1 downsampling unit followed by 5 units, N = 256
       level1 / alevel1 : 1 downsampling unit followed by 2 units, N = 512
     N is the value forwarded to the con layers inside block.  The levelN aliases are built
     from res / res_down and the alevelN aliases from ares / ares_down; apart from that the
     two families have the same shape. --><font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level1 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,res_down<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level2 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,res_down<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level3 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,res_down<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> level4 <font color='#5555FF'>=</font> res<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,res<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel1 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,ares_down<font color='#5555FF'>&lt;</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel2 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,ares_down<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel3 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,ares_down<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font> <font color='#0000FF'>using</font> alevel4 <font color='#5555FF'>=</font> ares<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,ares<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;

<!-- net_type, read from the innermost layer outwards: a 227 sized RGB image input, a
     con(64,7,7,2,2) layer, bn_con, relu, max_pool(3,3,2,2), then level4, level3, level2 and
     level1, then avg_pool_everything, an fc layer with 1000 outputs and the
     loss_multiclass_log loss layer. --><font color='#009900'>// training network type
</font><font color='#0000FF'>using</font> net_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'>&lt;</font>fc<font color='#5555FF'>&lt;</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'>&lt;</font>
                            level1<font color='#5555FF'>&lt;</font>
                            level2<font color='#5555FF'>&lt;</font>
                            level3<font color='#5555FF'>&lt;</font>
                            level4<font color='#5555FF'>&lt;</font>
                            max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,
                            input_rgb_image_sized<font color='#5555FF'>&lt;</font><font color='#979000'>227</font><font color='#5555FF'>&gt;</font>
                            <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;

<!-- anet_type has the same layer sequence as net_type, with every bn_con replaced by
     affine: the stem uses affine directly and the stages use the alevelN aliases, which are
     built from the affine based ares / ares_down units. --><font color='#009900'>// testing network type (replaced batch normalization with fixed affine transforms)
</font><font color='#0000FF'>using</font> anet_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'>&lt;</font>fc<font color='#5555FF'>&lt;</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'>&lt;</font>
                            alevel1<font color='#5555FF'>&lt;</font>
                            alevel2<font color='#5555FF'>&lt;</font>
                            alevel3<font color='#5555FF'>&lt;</font>
                            alevel4<font color='#5555FF'>&lt;</font>
                            max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'>&lt;</font>affine<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,
                            input_rgb_image_sized<font color='#5555FF'>&lt;</font><font color='#979000'>227</font><font color='#5555FF'>&gt;</font>
                            <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Returns a randomly sized, randomly positioned square cropping rectangle for img.
</font><font color='#009900'>// The side length is a random fraction (between mins and maxs) of the shorter image
</font><font color='#009900'>// dimension and the position is a random offset drawn from rnd.
</font>rectangle <b><a name='make_random_cropping_rect_resnet'></a>make_random_cropping_rect_resnet</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
    dlib::rand<font color='#5555FF'>&amp;</font> rnd
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// figure out what rectangle we want to crop from the image
</font>    <font color='#0000FF'><u>double</u></font> mins <font color='#5555FF'>=</font> <font color='#979000'>0.466666666</font>, maxs <font color='#5555FF'>=</font> <font color='#979000'>0.875</font>;
    <font color='#0000FF'>auto</font> scale <font color='#5555FF'>=</font> mins <font color='#5555FF'>+</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font face='Lucida Console'>(</font>maxs<font color='#5555FF'>-</font>mins<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>auto</font> size <font color='#5555FF'>=</font> scale<font color='#5555FF'>*</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    rectangle <font color='#BB00BB'>rect</font><font face='Lucida Console'>(</font>size, size<font face='Lucida Console'>)</font>;
    <font color='#009900'>// randomly shift the box around.  The horizontal and vertical offsets are drawn in
</font>    <font color='#009900'>// two separate statements: the evaluation order of function arguments is unspecified
</font>    <font color='#009900'>// in C++, so drawing both inside one constructor call would make the crop obtained
</font>    <font color='#009900'>// from a given rnd state depend on the compiler.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> x <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> y <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    point <font color='#BB00BB'>offset</font><font face='Lucida Console'>(</font>x, y<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> <font color='#BB00BB'>move_rect</font><font face='Lucida Console'>(</font>rect, offset<font face='Lucida Console'>)</font>;
<b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Writes one randomly augmented 227x227 training sample taken from img into crop:
</font><font color='#009900'>// a random crop, a random left/right flip and a random color offset.
</font><font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_image'></a>randomly_crop_image</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
    matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> crop,
    dlib::rand<font color='#5555FF'>&amp;</font> rnd
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Choose a random region of img and describe it as a 227x227 chip.
</font>    <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>chip</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>img, chip, crop<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Mirror the chip left to right when the random draw exceeds 0.5.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> mirror <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0.5</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>mirror<font face='Lucida Console'>)</font>
        crop <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Finish by randomly adjusting the colors.
</font>    <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop, rnd<font face='Lucida Console'>)</font>;
<b>}</b>

<font color='#009900'>// Fills crops with num_crops randomly augmented 227x227 samples taken from img.  Each
</font><font color='#009900'>// sample gets its own random crop, random left/right flip and random color offset.
</font><font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_images'></a>randomly_crop_images</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
    dlib::array<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> crops,
    dlib::rand<font color='#5555FF'>&amp;</font> rnd,
    <font color='#0000FF'><u>long</u></font> num_crops
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// First decide on all the crop locations, one 227x227 chip request per sample.
</font>    std::vector<font color='#5555FF'>&lt;</font>chip_details<font color='#5555FF'>&gt;</font> chips;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'>&lt;</font> num_crops; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font>
        chips.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Then extract every chip in a single call.
</font>    <font color='#BB00BB'>extract_image_chips</font><font face='Lucida Console'>(</font>img, chips, crops<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Finally augment each extracted chip in place.
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> chip : crops<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Mirror the chip left to right when the random draw exceeds 0.5.
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0.5</font><font face='Lucida Console'>)</font>
            chip <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>chip<font face='Lucida Console'>)</font>;

        <font color='#009900'>// Then randomly adjust its colors.
</font>        <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>chip, rnd<font face='Lucida Console'>)</font>;
    <b>}</b>
<b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<!-- image_info: one entry of a dataset listing, as filled in by the two
     get_imagenet listing functions that follow.
       filename      : path of the image file
       label         : class label string (the class subfolder name for training data,
                       the label token read from the listing file for validation data)
       numeric_label : integer class index; it is advanced by one each time the listing
                       code moves on to a new label --><font color='#0000FF'>struct</font> <b><a name='image_info'></a>image_info</b>
<b>{</b>
    string filename;
    string label;
    <font color='#0000FF'><u>long</u></font> numeric_label;
<b>}</b>;

<font color='#009900'>// Builds the training listing: one image_info per file found in each sub directory of
</font><font color='#009900'>// images_folder.  The sub directory name is the label, and numeric labels are assigned
</font><font color='#009900'>// from 0 upwards following the sorted order of the sub directories.
</font>std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> <b><a name='get_imagenet_train_listing'></a>get_imagenet_train_listing</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> images_folder
<font face='Lucida Console'>)</font>
<b>{</b>
    std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> results;
    image_info temp;
    temp.numeric_label <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#009900'>// We will loop over all the label types in the dataset, each is contained in a subfolder.
</font>    <font color='#0000FF'>auto</font> subdirs <font color='#5555FF'>=</font> <font color='#BB00BB'>directory</font><font face='Lucida Console'>(</font>images_folder<font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_dirs</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// But first, sort the sub directories so the numeric labels will be assigned in sorted order.
</font>    std::<font color='#BB00BB'>sort</font><font face='Lucida Console'>(</font>subdirs.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, subdirs.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// Iterate by const reference so the directory and file objects are not copied.
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> subdir : subdirs<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Now get all the images in this label type
</font>        temp.label <font color='#5555FF'>=</font> subdir.<font color='#BB00BB'>name</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> image_file : subdir.<font color='#BB00BB'>get_files</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            temp.filename <font color='#5555FF'>=</font> image_file;
            results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
        <b>}</b>
        <font color='#009900'>// Every image of the next sub directory gets the next numeric label.
</font>        <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label;
    <b>}</b>
    <font color='#0000FF'>return</font> results;
<b>}</b>

<font color='#009900'>// Builds the validation listing from validation_images_file, which is read as a sequence
</font><font color='#009900'>// of whitespace separated (label, filename) pairs.  Each filename is taken relative to
</font><font color='#009900'>// imagenet_root_dir.  The numeric label starts at 0 and is advanced every time the label
</font><font color='#009900'>// differs from the one on the previous entry.  The program exits with status 1 if the
</font><font color='#009900'>// listing file cannot be opened or if a listed image file does not exist.
</font>std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> <b><a name='get_imagenet_val_listing'></a>get_imagenet_val_listing</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> imagenet_root_dir,
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> validation_images_file 
<font face='Lucida Console'>)</font>
<b>{</b>
    ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font>validation_images_file<font face='Lucida Console'>)</font>;
    <font color='#009900'>// Fail loudly if the listing can't be read, rather than silently returning an empty
</font>    <font color='#009900'>// listing.  This mirrors the handling of a missing image file below.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>fin<font face='Lucida Console'>)</font>
    <b>{</b>
        cerr <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>unable to open file! </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> validation_images_file <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
    <b>}</b>
    string label, filename;
    std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font> results;
    image_info temp;
    <font color='#009900'>// Start at -1 so the first label read (which differs from the empty temp.label)
</font>    <font color='#009900'>// bumps the numeric label to 0.
</font>    temp.numeric_label <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>;
    <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>fin <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> label <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> filename<font face='Lucida Console'>)</font>
    <b>{</b>
        temp.filename <font color='#5555FF'>=</font> imagenet_root_dir<font color='#5555FF'>+</font>"<font color='#CC0000'>/</font>"<font color='#5555FF'>+</font>filename;
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font><font color='#BB00BB'>file_exists</font><font face='Lucida Console'>(</font>temp.filename<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            cerr <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>file doesn't exist! </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> temp.filename <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
            <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
        <b>}</b>
        <font color='#009900'>// A change of label relative to the previous entry starts a new class.
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>label <font color='#5555FF'>!</font><font color='#5555FF'>=</font> temp.label<font face='Lucida Console'>)</font>
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label;

        temp.label <font color='#5555FF'>=</font> label;
        results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#0000FF'>return</font> results;
<b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>3</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>To run this program you need a copy of the imagenet ILSVRC2015 dataset and</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>also the file http://dlib.net/files/imagenet2015_validation_images.txt.bz2</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>With those things, you call this program like this: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>./dnn_imagenet_train_ex /path/to/ILSVRC2015 imagenet2015_validation_images.txt</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\nSCANNING IMAGENET DATASET\n</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_imagenet_train_listing</font><font face='Lucida Console'>(</font><font color='#BB00BB'>string</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font color='#5555FF'>+</font>"<font color='#CC0000'>/Data/CLS-LOC/train/</font>"<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>images in dataset: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> number_of_classes <font color='#5555FF'>=</font> listing.<font color='#BB00BB'>back</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.numeric_label<font color='#5555FF'>+</font><font color='#979000'>1</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> number_of_classes <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1000</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Didn't find the imagenet dataset. </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>
        
    <font color='#BB00BB'>set_dnn_prefer_smallest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;


    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>;

    net_type net;
    dnn_trainer<font color='#5555FF'>&lt;</font>net_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>net,<font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>imagenet_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// This threshold is probably excessively large.  You could likely get good results
</font>    <font color='#009900'>// with a smaller value but if you aren't in a hurry this value will surely work well.
</font>    trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>20000</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// Since the progress threshold is so large might as well set the batch normalization
</font>    <font color='#009900'>// stats window to something big too.
</font>    <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>net, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>;

    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> samples;
    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> labels;

    <font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops.  It's
</font>    <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
</font>    <font color='#009900'>// threads for this kind of data preparation helps us do that.  Each thread puts the
</font>    <font color='#009900'>// crops into the data queue.
</font>    dlib::pipe<font color='#5555FF'>&lt;</font>std::pair<font color='#5555FF'>&lt;</font>image_info,matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>data, <font color='#5555FF'>&amp;</font>listing]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font>
    <b>{</b>
        dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> img;
        std::pair<font color='#5555FF'>&lt;</font>image_info, matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> temp;
        <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            temp.first <font color='#5555FF'>=</font> listing[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>];
            <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, temp.first.filename<font face='Lucida Console'>)</font>;
            <font color='#BB00BB'>randomly_crop_image</font><font face='Lucida Console'>(</font>img, temp.second, rnd<font face='Lucida Console'>)</font>;
            data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>;
    std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;

    <font color='#009900'>// The main training loop.  Keep making mini-batches and giving them to the trainer.
</font>    <font color='#009900'>// We will run until the learning rate has dropped below 1e-3 times its initial value.
</font>    <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> initial_learning_rate<font color='#5555FF'>*</font><font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>
    <b>{</b>
        samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// make a 160 image mini-batch
</font>        std::pair<font color='#5555FF'>&lt;</font>image_info, matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> img;
        <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> <font color='#979000'>160</font><font face='Lucida Console'>)</font>
        <b>{</b>
            data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>img<font face='Lucida Console'>)</font>;

            samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>img.second<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
            labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>img.first.numeric_label<font face='Lucida Console'>)</font>;
        <b>}</b>

        trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before
</font>    <font color='#009900'>// moving on.
</font>    data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// also wait for threaded processing to stop in the trainer.
</font>    trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>saving network</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>resnet34.dnn</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> net;






    <font color='#009900'>// Now test the network on the imagenet validation dataset.  First, make a testing
</font>    <font color='#009900'>// network with softmax as the final layer.  We don't have to do this if we just wanted
</font>    <font color='#009900'>// to test the "top1 accuracy" since the normal network outputs the class prediction.
</font>    <font color='#009900'>// But this snet object will make getting the top5 predictions easy as it directly
</font>    <font color='#009900'>// outputs the probability of each class as its final output.
</font>    softmax<font color='#5555FF'>&lt;</font>anet_type::subnet_type<font color='#5555FF'>&gt;</font> snet; snet.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Testing network on imagenet validation dataset...</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'><u>int</u></font> num_right_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'><u>int</u></font> num_wrong_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// loop over all the imagenet validation images
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> l : <font color='#BB00BB'>get_imagenet_val_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
    <b>{</b>
        dlib::array<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> images;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> img;
        <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, l.filename<font face='Lucida Console'>)</font>;
        <font color='#009900'>// Grab 16 random crops from the image.  We will run all of them through the
</font>        <font color='#009900'>// network and average the results.
</font>        <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> num_crops <font color='#5555FF'>=</font> <font color='#979000'>16</font>;
        <font color='#BB00BB'>randomly_crop_images</font><font face='Lucida Console'>(</font>img, images, rnd, num_crops<font face='Lucida Console'>)</font>;
        <font color='#009900'>// p(i) == the probability the image contains an object of class i.
</font>        matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font>,<font color='#979000'>1</font>,<font color='#979000'>1000</font><font color='#5555FF'>&gt;</font> p <font color='#5555FF'>=</font> <font color='#BB00BB'>sum_rows</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font><font color='#BB00BB'>snet</font><font face='Lucida Console'>(</font>images.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, images.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>num_crops;

        <font color='#009900'>// check top 1 accuracy
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font>
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right_top1;
        <font color='#0000FF'>else</font>
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong_top1;

        <font color='#009900'>// check top 5 accuracy
</font>        <font color='#0000FF'><u>bool</u></font> found_match <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'>&lt;</font> <font color='#979000'>5</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'><u>long</u></font> predicted_label <font color='#5555FF'>=</font> <font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>;
            <font color='#BB00BB'>p</font><font face='Lucida Console'>(</font>predicted_label<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>predicted_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font>
            <b>{</b>
                found_match <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
                <font color='#0000FF'>break</font>;
            <b>}</b>

        <b>}</b>
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>found_match<font face='Lucida Console'>)</font>
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right;
        <font color='#0000FF'>else</font>
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong;
    <b>}</b>
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>val top5 accuracy:  </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> num_right<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right<font color='#5555FF'>+</font>num_wrong<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>val top1 accuracy:  </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> num_right_top1<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right_top1<font color='#5555FF'>+</font>num_wrong_top1<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
<b>{</b>
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<b>}</b>


</pre></body></html>